{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Quora_BERT.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/eaishwa/quora-ques-pair-similarity/blob/master/Quora-QuesPairSim-BERT.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WCAuDAQzKgxn",
        "colab_type": "text"
      },
      "source": [
        "# Using BERT model to find if two questions are similar - Quora Question Pair Dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BxaWlte6Kvs-",
        "colab_type": "text"
      },
      "source": [
        "This notebook gives step by step implementation of using the BERT model developed by google, to fine tune on the famous QQP dataset on kaggle. The performance on this dataset is reported by Google in the paper, but this is a guide to implement the fine tuning of the BERT model for a sentence pair classification problem. This guide shall be followed for similar downstream tasks provided you have the task specific labelled dataset.\n",
        "\n",
        "This notebook uses the pytorch hugging face implementation of BERT namely PyTorch-Transformers."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZGSMLbacLtcq",
        "colab_type": "text"
      },
      "source": [
        "### Upload the train data from the source."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7UqtNkxxJKM0",
        "colab_type": "code",
        "colab": {
          "resources": {
            "http://localhost:8080/nbextensions/google.colab/files.js": {
              "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7Ci8vIE1heCBhbW91bnQgb2YgdGltZSB0byBibG9jayB3YWl0aW5nIGZvciB0aGUgdXNlci4KY29uc3QgRklMRV9DSEFOR0VfVElNRU9VVF9NUyA9IDMwICogMTAwMDsKCmZ1bmN0aW9uIF91cGxvYWRGaWxlcyhpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IHN0ZXBzID0gdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKTsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIC8vIENhY2hlIHN0ZXBzIG9uIHRoZSBvdXRwdXRFbGVtZW50IHRvIG1ha2UgaXQgYXZhaWxhYmxlIGZvciB0aGUgbmV4dCBjYWxsCiAgLy8gdG8gdXBsb2FkRmlsZXNDb250aW51ZSBmcm9tIFB5dGhvbi4KICBvdXRwdXRFbGVtZW50LnN0ZXBzID0gc3RlcHM7CgogIHJldHVybiBfdXBsb2FkRmlsZXNDb250aW51ZShvdXRwdXRJZCk7Cn0KCi8vIFRoaXMgaXMgcm91Z2hseSBhbiBhc
3luYyBnZW5lcmF0b3IgKG5vdCBzdXBwb3J0ZWQgaW4gdGhlIGJyb3dzZXIgeWV0KSwKLy8gd2hlcmUgdGhlcmUgYXJlIG11bHRpcGxlIGFzeW5jaHJvbm91cyBzdGVwcyBhbmQgdGhlIFB5dGhvbiBzaWRlIGlzIGdvaW5nCi8vIHRvIHBvbGwgZm9yIGNvbXBsZXRpb24gb2YgZWFjaCBzdGVwLgovLyBUaGlzIHVzZXMgYSBQcm9taXNlIHRvIGJsb2NrIHRoZSBweXRob24gc2lkZSBvbiBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcCwKLy8gdGhlbiBwYXNzZXMgdGhlIHJlc3VsdCBvZiB0aGUgcHJldmlvdXMgc3RlcCBhcyB0aGUgaW5wdXQgdG8gdGhlIG5leHQgc3RlcC4KZnVuY3Rpb24gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpIHsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIGNvbnN0IHN0ZXBzID0gb3V0cHV0RWxlbWVudC5zdGVwczsKCiAgY29uc3QgbmV4dCA9IHN0ZXBzLm5leHQob3V0cHV0RWxlbWVudC5sYXN0UHJvbWlzZVZhbHVlKTsKICByZXR1cm4gUHJvbWlzZS5yZXNvbHZlKG5leHQudmFsdWUucHJvbWlzZSkudGhlbigodmFsdWUpID0+IHsKICAgIC8vIENhY2hlIHRoZSBsYXN0IHByb21pc2UgdmFsdWUgdG8gbWFrZSBpdCBhdmFpbGFibGUgdG8gdGhlIG5leHQKICAgIC8vIHN0ZXAgb2YgdGhlIGdlbmVyYXRvci4KICAgIG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSA9IHZhbHVlOwogICAgcmV0dXJuIG5leHQudmFsdWUucmVzcG9uc2U7CiAgfSk7Cn0KCi8qKgogKiBHZW5lcmF0b3IgZnVuY3Rpb24gd2hpY2ggaXMgY2FsbGVkIGJldHdlZW4gZWFjaCBhc3luYyBzdGVwIG9mIHRoZSB1cGxvYWQKICogcHJvY2Vzcy4KICogQHBhcmFtIHtzdHJpbmd9IGlucHV0SWQgRWxlbWVudCBJRCBvZiB0aGUgaW5wdXQgZmlsZSBwaWNrZXIgZWxlbWVudC4KICogQHBhcmFtIHtzdHJpbmd9IG91dHB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIG91dHB1dCBkaXNwbGF5LgogKiBAcmV0dXJuIHshSXRlcmFibGU8IU9iamVjdD59IEl0ZXJhYmxlIG9mIG5leHQgc3RlcHMuCiAqLwpmdW5jdGlvbiogdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKSB7CiAgY29uc3QgaW5wdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQoaW5wdXRJZCk7CiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gZmFsc2U7CgogIGNvbnN0IG91dHB1dEVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50QnlJZChvdXRwdXRJZCk7CiAgb3V0cHV0RWxlbWVudC5pbm5lckhUTUwgPSAnJzsKCiAgY29uc3QgcGlja2VkUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBpbnB1dEVsZW1lbnQuYWRkRXZlbnRMaXN0ZW5lcignY2hhbmdlJywgKGUpID0+IHsKICAgICAgcmVzb2x2ZShlLnRhcmdldC5maWxlcyk7CiAgICB9KTsKICB9KTsKCiAgY29uc3QgY2FuY2VsID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnYnV0dG9uJyk7CiAgaW5wdXRFbGVtZW50LnBhc
mVudEVsZW1lbnQuYXBwZW5kQ2hpbGQoY2FuY2VsKTsKICBjYW5jZWwudGV4dENvbnRlbnQgPSAnQ2FuY2VsIHVwbG9hZCc7CiAgY29uc3QgY2FuY2VsUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBjYW5jZWwub25jbGljayA9ICgpID0+IHsKICAgICAgcmVzb2x2ZShudWxsKTsKICAgIH07CiAgfSk7CgogIC8vIENhbmNlbCB1cGxvYWQgaWYgdXNlciBoYXNuJ3QgcGlja2VkIGFueXRoaW5nIGluIHRpbWVvdXQuCiAgY29uc3QgdGltZW91dFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgc2V0VGltZW91dCgoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9LCBGSUxFX0NIQU5HRV9USU1FT1VUX01TKTsKICB9KTsKCiAgLy8gV2FpdCBmb3IgdGhlIHVzZXIgdG8gcGljayB0aGUgZmlsZXMuCiAgY29uc3QgZmlsZXMgPSB5aWVsZCB7CiAgICBwcm9taXNlOiBQcm9taXNlLnJhY2UoW3BpY2tlZFByb21pc2UsIHRpbWVvdXRQcm9taXNlLCBjYW5jZWxQcm9taXNlXSksCiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdzdGFydGluZycsCiAgICB9CiAgfTsKCiAgaWYgKCFmaWxlcykgewogICAgcmV0dXJuIHsKICAgICAgcmVzcG9uc2U6IHsKICAgICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICAgIH0KICAgIH07CiAgfQoKICBjYW5jZWwucmVtb3ZlKCk7CgogIC8vIERpc2FibGUgdGhlIGlucHV0IGVsZW1lbnQgc2luY2UgZnVydGhlciBwaWNrcyBhcmUgbm90IGFsbG93ZWQuCiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gdHJ1ZTsKCiAgZm9yIChjb25zdCBmaWxlIG9mIGZpbGVzKSB7CiAgICBjb25zdCBsaSA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2xpJyk7CiAgICBsaS5hcHBlbmQoc3BhbihmaWxlLm5hbWUsIHtmb250V2VpZ2h0OiAnYm9sZCd9KSk7CiAgICBsaS5hcHBlbmQoc3BhbigKICAgICAgICBgKCR7ZmlsZS50eXBlIHx8ICduL2EnfSkgLSAke2ZpbGUuc2l6ZX0gYnl0ZXMsIGAgKwogICAgICAgIGBsYXN0IG1vZGlmaWVkOiAkewogICAgICAgICAgICBmaWxlLmxhc3RNb2RpZmllZERhdGUgPyBmaWxlLmxhc3RNb2RpZmllZERhdGUudG9Mb2NhbGVEYXRlU3RyaW5nKCkgOgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnbi9hJ30gLSBgKSk7CiAgICBjb25zdCBwZXJjZW50ID0gc3BhbignMCUgZG9uZScpOwogICAgbGkuYXBwZW5kQ2hpbGQocGVyY2VudCk7CgogICAgb3V0cHV0RWxlbWVudC5hcHBlbmRDaGlsZChsaSk7CgogICAgY29uc3QgZmlsZURhdGFQcm9taXNlID0gbmV3IFByb21pc2UoKHJlc29sdmUpID0+IHsKICAgICAgY29uc3QgcmVhZGVyID0gbmV3IEZpbGVSZWFkZXIoKTsKICAgICAgcmVhZGVyLm9ubG9hZCA9IChlKSA9PiB7CiAgICAgICAgcmVzb2x2ZShlLnRhcmdldC5yZXN1bHQpOwogICAgICB9OwogICAgICByZWFkZXIucmVhZEFzQXJyYXlCdWZmZXIoZmlsZSk7CiAgICB9KTsKICAgIC8vIFdhaXQgZm9yIHRoZSBkY
XRhIHRvIGJlIHJlYWR5LgogICAgbGV0IGZpbGVEYXRhID0geWllbGQgewogICAgICBwcm9taXNlOiBmaWxlRGF0YVByb21pc2UsCiAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgYWN0aW9uOiAnY29udGludWUnLAogICAgICB9CiAgICB9OwoKICAgIC8vIFVzZSBhIGNodW5rZWQgc2VuZGluZyB0byBhdm9pZCBtZXNzYWdlIHNpemUgbGltaXRzLiBTZWUgYi82MjExNTY2MC4KICAgIGxldCBwb3NpdGlvbiA9IDA7CiAgICB3aGlsZSAocG9zaXRpb24gPCBmaWxlRGF0YS5ieXRlTGVuZ3RoKSB7CiAgICAgIGNvbnN0IGxlbmd0aCA9IE1hdGgubWluKGZpbGVEYXRhLmJ5dGVMZW5ndGggLSBwb3NpdGlvbiwgTUFYX1BBWUxPQURfU0laRSk7CiAgICAgIGNvbnN0IGNodW5rID0gbmV3IFVpbnQ4QXJyYXkoZmlsZURhdGEsIHBvc2l0aW9uLCBsZW5ndGgpOwogICAgICBwb3NpdGlvbiArPSBsZW5ndGg7CgogICAgICBjb25zdCBiYXNlNjQgPSBidG9hKFN0cmluZy5mcm9tQ2hhckNvZGUuYXBwbHkobnVsbCwgY2h1bmspKTsKICAgICAgeWllbGQgewogICAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgICBhY3Rpb246ICdhcHBlbmQnLAogICAgICAgICAgZmlsZTogZmlsZS5uYW1lLAogICAgICAgICAgZGF0YTogYmFzZTY0LAogICAgICAgIH0sCiAgICAgIH07CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPQogICAgICAgICAgYCR7TWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCl9JSBkb25lYDsKICAgIH0KICB9CgogIC8vIEFsbCBkb25lLgogIHlpZWxkIHsKICAgIHJlc3BvbnNlOiB7CiAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgIH0KICB9Owp9CgpzY29wZS5nb29nbGUgPSBzY29wZS5nb29nbGUgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYiA9IHNjb3BlLmdvb2dsZS5jb2xhYiB8fCB7fTsKc2NvcGUuZ29vZ2xlLmNvbGFiLl9maWxlcyA9IHsKICBfdXBsb2FkRmlsZXMsCiAgX3VwbG9hZEZpbGVzQ29udGludWUsCn07Cn0pKHNlbGYpOwo=",
              "ok": true,
              "headers": [
                [
                  "content-type",
                  "application/javascript"
                ]
              ],
              "status": 200,
              "status_text": ""
            }
          },
          "base_uri": "https://localhost:8080/",
          "height": 75
        },
        "outputId": "cc0b3083-5b51-4ef3-e88c-823a791168a2"
      },
      "source": [
        "from google.colab import files\n",
        "uploaded = files.upload()"
      ],
      "execution_count": 98,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Saving questions.csv to questions (3).csv\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rOJxCXLHL_5P",
        "colab_type": "text"
      },
      "source": [
        "### Installing the pytorch-transformers library and importing the necessary requirements."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TZDWWMXnJPSl",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 783
        },
        "outputId": "8c0bc7cf-6370-44a6-ea12-85bb1609b401"
      },
      "source": [
        "!pip install pytorch-transformers\n",
        "!pip install pytorch-pretrained-bert pytorch-nlp\n",
        "\n",
        "import logging\n",
        "logging.basicConfig(level=logging.INFO)\n",
        "import torch\n",
        "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
        "from keras.preprocessing.sequence import pad_sequences\n",
        "from sklearn.model_selection import train_test_split\n",
        "from pytorch_pretrained_bert import BertTokenizer, BertConfig\n",
        "from pytorch_pretrained_bert import BertAdam, BertForSequenceClassification\n",
        "from tqdm import tqdm, trange\n",
        "import pandas as pd\n",
        "import io\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "% matplotlib inline"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pytorch-transformers in /usr/local/lib/python3.6/dist-packages (1.2.0)\n",
            "Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (from pytorch-transformers) (2019.8.19)\n",
            "Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from pytorch-transformers) (1.9.220)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from pytorch-transformers) (1.16.5)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from pytorch-transformers) (4.28.1)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from pytorch-transformers) (2.21.0)\n",
            "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.6/dist-packages (from pytorch-transformers) (0.1.83)\n",
            "Requirement already satisfied: sacremoses in /usr/local/lib/python3.6/dist-packages (from pytorch-transformers) (0.0.33)\n",
            "Requirement already satisfied: torch>=1.0.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-transformers) (1.1.0)\n",
            "Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-transformers) (0.2.1)\n",
            "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-transformers) (0.9.4)\n",
            "Requirement already satisfied: botocore<1.13.0,>=1.12.220 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-transformers) (1.12.220)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-transformers) (2019.6.16)\n",
            "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-transformers) (2.8)\n",
            "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-transformers) (3.0.4)\n",
            "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-transformers) (1.24.3)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->pytorch-transformers) (7.0)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->pytorch-transformers) (1.12.0)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->pytorch-transformers) (0.13.2)\n",
            "Requirement already satisfied: python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\" in /usr/local/lib/python3.6/dist-packages (from botocore<1.13.0,>=1.12.220->boto3->pytorch-transformers) (2.5.3)\n",
            "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.13.0,>=1.12.220->boto3->pytorch-transformers) (0.15.2)\n",
            "Requirement already satisfied: pytorch-pretrained-bert in /usr/local/lib/python3.6/dist-packages (0.6.2)\n",
            "Requirement already satisfied: pytorch-nlp in /usr/local/lib/python3.6/dist-packages (0.4.1)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (4.28.1)\n",
            "Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (2019.8.19)\n",
            "Requirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (1.9.220)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (1.16.5)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (2.21.0)\n",
            "Requirement already satisfied: torch>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (1.1.0)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.6/dist-packages (from pytorch-nlp) (0.24.2)\n",
            "Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-bert) (0.2.1)\n",
            "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-bert) (0.9.4)\n",
            "Requirement already satisfied: botocore<1.13.0,>=1.12.220 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-bert) (1.12.220)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-bert) (2019.6.16)\n",
            "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-bert) (3.0.4)\n",
            "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-bert) (1.24.3)\n",
            "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-bert) (2.8)\n",
            "Requirement already satisfied: pytz>=2011k in /usr/local/lib/python3.6/dist-packages (from pandas->pytorch-nlp) (2018.9)\n",
            "Requirement already satisfied: python-dateutil>=2.5.0 in /usr/local/lib/python3.6/dist-packages (from pandas->pytorch-nlp) (2.5.3)\n",
            "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.13.0,>=1.12.220->boto3->pytorch-pretrained-bert) (0.15.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil>=2.5.0->pandas->pytorch-nlp) (1.12.0)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "Using TensorFlow backend.\n",
            "INFO:pytorch_pretrained_bert.modeling:Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zn3MPlLCK3ZM",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 87
        },
        "outputId": "d0766eee-5709-4378-b88e-3fbbe8cf1ec6"
      },
      "source": [
        "#data = pd.read_csv(io.BytesIO(uploaded['questions.csv']))\n",
        "data = pd.read_csv('questions (3).csv')\n",
        "data = data.dropna()\n",
        "print(\"The percentage of non similar question pairs is : \")\n",
        "print(len(data[data['is_duplicate']==0].index)*100/len(data.index))\n",
        "print(\"The percentage of similar question pairs is : \")\n",
        "print(len(data[data['is_duplicate']==1].index)*100/len(data.index))\n"
      ],
      "execution_count": 65,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "The percentage of non similar question pairs is : \n",
            "63.07487609682749\n",
            "The percentage of similar question pairs is : \n",
            "36.92512390317251\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_L4YsySfMfhf",
        "colab_type": "text"
      },
      "source": [
        "### Sampling 5000 data points to enhance training speed and optimize memory usage. \n",
        "\n",
        "The original dataset consists of 404,348 samples."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DKez_G1oM7dR",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 87
        },
        "outputId": "9a9f5d24-b88c-4047-f251-cf35cc2e0fd9"
      },
      "source": [
        "# Randomly sample 5000 data points\n",
        "data = data.sample(n=5000)\n",
        "\n",
        "print(\"The percentage of non similar question pairs after sampling is : \")\n",
        "print(len(data[data['is_duplicate']==0].index)*100/len(data.index))\n",
        "print(\"The percentage of similar question pairs after sampling is : \")\n",
        "print(len(data[data['is_duplicate']==1].index)*100/len(data.index))\n",
        "\n",
        "# store the labels \n",
        "labels = data.is_duplicate.values"
      ],
      "execution_count": 66,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "The percentage of non similar question pairs after sampling is : \n",
            "62.76\n",
            "The percentage of similar question pairs after sampling is : \n",
            "37.24\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m0jQuaVnNUoX",
        "colab_type": "text"
      },
      "source": [
        "From the above, we see that the class percentages are fairly maintained even after sampling."
      ]
    },
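    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Note that `data.sample(n=5000)` above draws a different subset on every run and preserves the class proportions only approximately. The cell below is a minimal sketch, not part of the original notebook, of a reproducible sample that keeps the class proportions by sampling each class separately. The helper `stratified_sample`, its `seed` parameter, and the toy DataFrame are illustrative assumptions. Because each class size is rounded, the result can differ from `n` by a row or two."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Hypothetical helper: sample about n rows while keeping the label proportions of df\n",
        "def stratified_sample(df, label_col, n, seed=0):\n",
        "  frac = n / len(df)\n",
        "  parts = [\n",
        "      group.sample(n=int(round(len(group) * frac)), random_state=seed)\n",
        "      for _, group in df.groupby(label_col)\n",
        "  ]\n",
        "  # concatenate the per-class samples and shuffle the rows\n",
        "  return pd.concat(parts).sample(frac=1, random_state=seed)\n",
        "\n",
        "# Toy data with roughly the same 63/37 class split as the dataset above\n",
        "toy = pd.DataFrame({'is_duplicate': [0] * 63 + [1] * 37})\n",
        "subset = stratified_sample(toy, 'is_duplicate', n=20)\n",
        "print(subset['is_duplicate'].value_counts(normalize=True))"
      ],
      "execution_count": null,
      "outputs": []
    },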
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fyMFVNt9NeH5",
        "colab_type": "text"
      },
      "source": [
        "### Check the meta information about the dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YU8jJJj7-aE4",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 208
        },
        "outputId": "c287ed00-528b-4b48-8656-31889bfdded6"
      },
      "source": [
        "data.info()"
      ],
      "execution_count": 67,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "<class 'pandas.core.frame.DataFrame'>\n",
            "Int64Index: 5000 entries, 63770 to 95141\n",
            "Data columns (total 6 columns):\n",
            "id              5000 non-null int64\n",
            "qid1            5000 non-null int64\n",
            "qid2            5000 non-null int64\n",
            "question1       5000 non-null object\n",
            "question2       5000 non-null object\n",
            "is_duplicate    5000 non-null int64\n",
            "dtypes: int64(4), object(2)\n",
            "memory usage: 273.4+ KB\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Af2WqCHNpaG",
        "colab_type": "text"
      },
      "source": [
        "### Prepare input data for BERT\n",
        "\n",
        "We need special data preparation methods for BERT. Following are the steps.\n",
        "\n",
        "1) Tokenize the sentences using the library. \n",
        "\n",
        "2) Append the string \"[CLS]\" to the beginning of the sentence and \"[SEP]\" to the end of the sentence. \n",
        "\n",
        "3) In case of sentence pairs, we need to append \"[SEP]\" at the end of the second sentence too. \n",
        "\n",
        "4) Obtain the input ids of the tokens from the output of the tokenizer. The model needs this to identify tokens uniquely.\n",
        "\n",
        "5) BERT accepts input sequences in fixed sizes such as 128, 256, 320, 384, 512. So we need to truncate larger sequences or pad smaller sequences with 0.\n",
        "\n",
        "6) A segment mask is to be specified to identify if the input is a single sentence or pair of sentences. Indicative values such 0 for first sentence and 1 for second sentence are used for this purpose.\n",
        "\n",
        "7) An attention mask is to be specified, to let the model know which are the tokens and which are the paddings we introduced in step 5. 1 indicates token and 0 indicates padding.\n",
        "\n",
        "This is the format in which the original BERT model was trained by google. So the users are also expected to follow the same for the best results."
      ]
    },
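    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Steps 1 to 4 and step 6 above can be sketched on a single question pair without any library. This is an illustration under stated assumptions, not the notebook's actual pipeline: `build_pair` is a hypothetical helper, whitespace splitting stands in for BERT's WordPiece tokenizer, and the vocabulary is built from the example itself rather than loaded from the pre-trained model."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Hypothetical helper: build the token sequence and segment ids for one pair\n",
        "def build_pair(q1, q2):\n",
        "  # toy tokenizer (assumption): lowercase and split on whitespace\n",
        "  tok1 = q1.lower().split()\n",
        "  tok2 = q2.lower().split()\n",
        "  # [CLS] first question [SEP] second question [SEP]\n",
        "  tokens = ['[CLS]'] + tok1 + ['[SEP]'] + tok2 + ['[SEP]']\n",
        "  # 0 covers [CLS], the first question and its [SEP]; 1 covers the rest\n",
        "  segment_ids = [0] * (len(tok1) + 2) + [1] * (len(tok2) + 1)\n",
        "  return tokens, segment_ids\n",
        "\n",
        "tokens, segment_ids = build_pair('How do I learn Python',\n",
        "                                 'What is the best way to learn Python')\n",
        "\n",
        "# Toy vocabulary (assumption): id 0 is reserved for padding\n",
        "vocab = {tok: idx for idx, tok in enumerate(['[PAD]'] + sorted(set(tokens)))}\n",
        "input_ids = [vocab[tok] for tok in tokens]\n",
        "\n",
        "print(tokens)\n",
        "print(segment_ids)\n",
        "print(input_ids)"
      ],
      "execution_count": null,
      "outputs": []
    },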
    {
      "cell_type": "code",
      "metadata": {
        "id": "hJTEYRFrKppD",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 55
        },
        "outputId": "6ad33846-17c9-4a02-c222-e1d2d6ec13f5"
      },
      "source": [
        "# Load pre-trained model tokenizer (vocabulary)\n",
        "\n",
        "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
        "\n",
        "# function to tokenize and generate input ids for the tokens\n",
        "# returns a list of input ids\n",
        "\n",
        "def prep_data(ques1, ques2):\n",
        "  all_input_ids = []\n",
        "  \n",
        "  for (q1,q2) in zip(ques1, ques2):\n",
        "    \n",
        "    # first sentence is appended with [CLS] and [SEP] in the beginning and end\n",
        "    q1 = '[CLS] ' + q1 + ' [SEP] '\n",
        "    tokens = tokenizer.tokenize(q1)\n",
        "    \n",
        "    # 0 denotes first sentence\n",
        "    seg_ids = [0] * len(tokens)\n",
        "    \n",
        "    # second sentence is appended with [SEP] in the end\n",
        "    q2 = q2 + ' [SEP] '\n",
        "    tok_q2 = tokenizer.tokenize(q2)\n",
        "    \n",
        "    # seg ids is appended with 1 to denote second sentence\n",
        "    seg_ids += [1] * len(tok_q2)\n",
        "    \n",
        "    # first and second sentence tokens are appended together\n",
        "    tokens += tok_q2\n",
        "    \n",
        "    # input ids are generated for the tokens (one question pair)\n",
        "    input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
        "\n",
        "    # input ids are stored in a separate list\n",
        "    all_input_ids.append(input_ids)\n",
        "    \n",
        "  return all_input_ids\n",
        "\n",
        "\n",
        "all_input_ids = prep_data(data['question1'].values, data['question2'].values)"
      ],
      "execution_count": 68,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /root/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wQO9-B9-pO3H",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# set MAX_LEN as one of 128, 256, 320, 384, 512\n",
        "MAX_LEN = 128\n",
        "\n",
        "# Pad our input tokens\n",
        "pad_input_ids = pad_sequences(all_input_ids,\n",
        "                          maxlen=MAX_LEN, dtype=\"long\", truncating=\"post\", padding=\"post\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IHDVWbUQps5_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Create attention masks\n",
        "attention_masks = []\n",
        "\n",
        "# Create a mask of 1s for each token followed by 0s for padding\n",
        "for seq in pad_input_ids:\n",
        "  seq_mask = [float(i>0) for i in seq]\n",
        "  attention_masks.append(seq_mask)"
      ],
      "execution_count": 0,
      "outputs": []
    },
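    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The padding and masking cells above handle only the input ids. The segment ids from step 6 would need the same padding before they could be passed to the model. The cell below is a hedged sketch in NumPy of padding both lists to one length and deriving the attention mask from the padded ids. `pad_to_length` and the toy ids are illustrative assumptions, not part of the original notebook."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical helper: truncate or right-pad every sequence to max_len\n",
        "def pad_to_length(sequences, max_len, pad_value=0):\n",
        "  out = np.full((len(sequences), max_len), pad_value, dtype=np.int64)\n",
        "  for row, seq in enumerate(sequences):\n",
        "    trimmed = seq[:max_len]\n",
        "    out[row, :len(trimmed)] = trimmed\n",
        "  return out\n",
        "\n",
        "# Made-up input ids and matching segment ids for two question pairs\n",
        "toy_ids = [[5, 17, 6, 23, 6], [5, 31, 6, 42, 19, 6]]\n",
        "toy_segs = [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1, 1]]\n",
        "\n",
        "padded_ids = pad_to_length(toy_ids, max_len=8)\n",
        "padded_segs = pad_to_length(toy_segs, max_len=8)\n",
        "\n",
        "# 1.0 where a real token sits, 0.0 where padding was added\n",
        "toy_masks = (padded_ids > 0).astype(np.float32)\n",
        "\n",
        "print(padded_ids)\n",
        "print(padded_segs)\n",
        "print(toy_masks)"
      ],
      "execution_count": null,
      "outputs": []
    },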
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9VOM6w_NUBgY",
        "colab_type": "text"
      },
      "source": [
        "### Check if GPU is available"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WVYAbpmhjTmk",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "d694b05d-8448-40af-e40e-e8886ae019fa"
      },
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "device_name = tf.test.gpu_device_name()\n",
        "if device_name != '/device:GPU:0':\n",
        "  raise SystemError('GPU device not found')\n",
        "print('Found GPU at: {}'.format(device_name))"
      ],
      "execution_count": 71,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Found GPU at: /device:GPU:0\n"
          ],
          "name": "stdout"
        }
      ]
    },
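    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The check above queries TensorFlow, while the model in this notebook runs on PyTorch. As a minimal sketch, assuming only that `torch` is installed, the same check can be made through PyTorch directly. The `device` variable is an illustrative name; a model could then be moved with `model.to(device)`, which also works when no GPU is present."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "\n",
        "# Use the GPU when PyTorch can see one, otherwise fall back to the CPU\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "print('Using device:', device)\n",
        "\n",
        "if device.type == 'cuda':\n",
        "  print('GPU name:', torch.cuda.get_device_name(0))"
      ],
      "execution_count": null,
      "outputs": []
    },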
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j6AAXHf4UMBE",
        "colab_type": "text"
      },
      "source": [
        "### Obtain the BERT model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LiYzw4MzaJCz",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "d738a581-61f8-44b1-908b-0faeb1c96532"
      },
      "source": [
        "model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\", num_labels=2)\n",
        "model.cuda()"
      ],
      "execution_count": 72,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:pytorch_pretrained_bert.modeling:loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /root/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n",
            "INFO:pytorch_pretrained_bert.modeling:extracting archive file /root/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmp0h0cvmjx\n",
            "INFO:pytorch_pretrained_bert.modeling:Model config {\n",
            "  \"attention_probs_dropout_prob\": 0.1,\n",
            "  \"hidden_act\": \"gelu\",\n",
            "  \"hidden_dropout_prob\": 0.1,\n",
            "  \"hidden_size\": 768,\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 3072,\n",
            "  \"max_position_embeddings\": 512,\n",
            "  \"num_attention_heads\": 12,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"type_vocab_size\": 2,\n",
            "  \"vocab_size\": 30522\n",
            "}\n",
            "\n",
            "INFO:pytorch_pretrained_bert.modeling:Weights of BertForSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']\n",
            "INFO:pytorch_pretrained_bert.modeling:Weights from pretrained model not used in BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "BertForSequenceClassification(\n",
              "  (bert): BertModel(\n",
              "    (embeddings): BertEmbeddings(\n",
              "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
              "      (position_embeddings): Embedding(512, 768)\n",
              "      (token_type_embeddings): Embedding(2, 768)\n",
              "      (LayerNorm): BertLayerNorm()\n",
              "      (dropout): Dropout(p=0.1)\n",
              "    )\n",
              "    (encoder): BertEncoder(\n",
              "      (layer): ModuleList(\n",
              "        (0): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (1): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (2): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (3): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (4): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (5): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (6): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (7): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (8): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (9): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (10): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "        (11): BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): BertLayerNorm()\n",
              "              (dropout): Dropout(p=0.1)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): BertLayerNorm()\n",
              "            (dropout): Dropout(p=0.1)\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "    (pooler): BertPooler(\n",
              "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "      (activation): Tanh()\n",
              "    )\n",
              "  )\n",
              "  (dropout): Dropout(p=0.1)\n",
              "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 72
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GFRecISyfBIw",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "7a987a18-bc78-4c50-e7fa-02d8f4087b66"
      },
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "n_gpu = torch.cuda.device_count()\n",
        "torch.cuda.get_device_name(0)"
      ],
      "execution_count": 73,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'Tesla K80'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 73
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z5eOUhfvUbLF",
        "colab_type": "text"
      },
      "source": [
        "### Run the model\n",
        "\n",
        "Split the input data into train and validation sets. I keep 20% of the data for validation. Convert all the inputs to tensors which is format for this library."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y9xG7t3PjsTi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Use train_test_split to split our data into train and validation sets for training\n",
        "\n",
        "train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(pad_input_ids, labels, \n",
        "                                                            random_state=2018, test_size=0.2)\n",
        "train_masks, validation_masks, _, _ = train_test_split(attention_masks, pad_input_ids,\n",
        "                                             random_state=2018, test_size=0.2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mASzBDUdnz-A",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_inputs = torch.tensor(train_inputs)\n",
        "validation_inputs = torch.tensor(validation_inputs)\n",
        "train_labels = torch.tensor(train_labels)\n",
        "validation_labels = torch.tensor(validation_labels)\n",
        "train_masks = torch.tensor(train_masks)\n",
        "validation_masks = torch.tensor(validation_masks)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jHJgGF6sU6pe",
        "colab_type": "text"
      },
      "source": [
        "The DataLoader allows us to get only that particular batch neeed for that epoch. This helps save lot of memory because we wont be loading the entire data in memory during training."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-JC-gf0HpCCW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Select a batch size for training. For fine-tuning BERT on a specific task, the authors recommend a batch size of 16 or 32\n",
        "batch_size = 32\n",
        "\n",
        "# Create an iterator of our data with torch DataLoader. This helps save on memory during training because, unlike a for loop, \n",
        "# with an iterator the entire dataset does not need to be loaded into memory\n",
        "\n",
        "train_data = TensorDataset(train_inputs, train_masks, train_labels)\n",
        "train_sampler = RandomSampler(train_data)\n",
        "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n",
        "\n",
        "validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)\n",
        "validation_sampler = SequentialSampler(validation_data)\n",
        "validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
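    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A rough, library-free sketch of what the two samplers above do may help. The helper `batch_indices` is hypothetical (it is not part of PyTorch); it only shows how `n` examples are cut into index batches, in random order for training and in the original order for validation. Note that the last batch can be smaller than `batch_size`.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "def batch_indices(n, batch_size, shuffle, seed=0):\n",
        "    # Return a list of index batches that covers range(n) exactly once.\n",
        "    order = list(range(n))\n",
        "    if shuffle:\n",
        "        # Stand-in for RandomSampler: visit the examples in random order.\n",
        "        random.Random(seed).shuffle(order)\n",
        "    # With shuffle=False this behaves like SequentialSampler.\n",
        "    return [order[i:i + batch_size] for i in range(0, n, batch_size)]\n",
        "\n",
        "# 70 examples is an arbitrary size chosen for the illustration.\n",
        "batches = batch_indices(n=70, batch_size=32, shuffle=False)\n",
        "print([len(b) for b in batches])  # [32, 32, 6]\n",
        "```\n",
        "\n",
        "The real `DataLoader` additionally stacks the selected rows of the `TensorDataset` into tensors, which is what the training loop unpacks."
      ]
    },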
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7On2M8vvVi3d",
        "colab_type": "text"
      },
      "source": [
        "### Get the same hyperparameters as the model and define the Adam Optimizer"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QvUbptQbqi6i",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "param_optimizer = list(model.named_parameters())\n",
        "no_decay = ['bias', 'gamma', 'beta']\n",
        "optimizer_grouped_parameters = [\n",
        "    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
        "     'weight_decay_rate': 0.01},\n",
        "    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],\n",
        "     'weight_decay_rate': 0.0}\n",
        "]"
      ],
      "execution_count": 0,
      "outputs": []
    },
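    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The grouping above is driven purely by substring matches on parameter names, so the patterns have to match the names the model actually uses. Here is a minimal sketch of that matching on plain strings. The names in `names` are hypothetical examples modelled on the module names in the printed architecture, and `split_by_decay` is a helper made up for this illustration.\n",
        "\n",
        "```python\n",
        "# Hypothetical parameter names in the style returned by model.named_parameters().\n",
        "names = [\n",
        "    'bert.encoder.layer.0.attention.self.query.weight',\n",
        "    'bert.encoder.layer.0.attention.self.query.bias',\n",
        "    'bert.encoder.layer.0.output.LayerNorm.weight',\n",
        "    'bert.encoder.layer.0.output.LayerNorm.bias',\n",
        "    'classifier.weight',\n",
        "    'classifier.bias',\n",
        "]\n",
        "\n",
        "def split_by_decay(names, no_decay):\n",
        "    # A name is exempt from weight decay if any pattern is a substring of it.\n",
        "    skip = [n for n in names if any(nd in n for nd in no_decay)]\n",
        "    decay = [n for n in names if not any(nd in n for nd in no_decay)]\n",
        "    return decay, skip\n",
        "\n",
        "for patterns in (['bias', 'gamma', 'beta'], ['bias', 'LayerNorm.weight']):\n",
        "    decay, skip = split_by_decay(names, patterns)\n",
        "    print(patterns, '-> decayed:', len(decay), 'not decayed:', len(skip))\n",
        "```\n",
        "\n",
        "With the first pattern list, `LayerNorm.weight` ends up in the decayed group because neither `gamma` nor `beta` occurs in its name. The second list exempts it."
      ]
    },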
    {
      "cell_type": "code",
      "metadata": {
        "id": "6lEhkdsdqv9Y",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "9fe94e07-9466-4e29-f25c-b63caa356f26"
      },
      "source": [
        "# This variable contains all of the hyperparemeter information our training loop needs\n",
        "optimizer = BertAdam(optimizer_grouped_parameters,\n",
        "                     lr=2e-5,\n",
        "                     warmup=.1)"
      ],
      "execution_count": 78,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING:pytorch_pretrained_bert.optimization:t_total value of -1 results in schedule not being applied\n"
          ],
          "name": "stderr"
        }
      ]
    },
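    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the `warmup` argument concrete, here is a sketch of a generic linear warm-up followed by linear decay. It illustrates the general technique and is not claimed to be `BertAdam`'s exact formula. The function `lr_multiplier` and the value of `total_steps` are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "def lr_multiplier(step, total_steps, warmup=0.1):\n",
        "    # Fraction of training completed so far, between 0 and 1.\n",
        "    progress = step / total_steps\n",
        "    if progress < warmup:\n",
        "        # Linear warm-up from 0 to 1 over the first `warmup` fraction of steps.\n",
        "        return progress / warmup\n",
        "    # Linear decay from 1 back to 0 over the remaining steps.\n",
        "    return (1.0 - progress) / (1.0 - warmup)\n",
        "\n",
        "base_lr = 2e-5\n",
        "total_steps = 100  # assumed: number of batches per epoch * number of epochs\n",
        "for step in (0, 5, 10, 55, 100):\n",
        "    print(step, base_lr * lr_multiplier(step, total_steps))\n",
        "```\n",
        "\n",
        "A schedule like this needs the total number of steps up front, which is why the optimizer cannot apply warm-up when it is not told how long training will run."
      ]
    },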
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5LC6lT5iVqpr",
        "colab_type": "text"
      },
      "source": [
        "### Define a method to compute accuracy"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7Tbm46U_1G-M",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Function to calculate the accuracy of our predictions vs labels\n",
        "def accuracy(preds, labels):\n",
        "    pred_flat = np.argmax(preds, axis=1).flatten()\n",
        "    labels_flat = labels.flatten()\n",
        "    return np.sum(pred_flat == labels_flat) / len(labels_flat)"
      ],
      "execution_count": 0,
      "outputs": []
    },
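    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a quick sanity check, the accuracy function can be tried on a handful of made-up logits. The function is repeated so the snippet runs on its own. The values in `logits` and `labels` are invented for illustration, and reading column 1 as the duplicate class is an assumption about the label encoding.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def accuracy(preds, labels):\n",
        "    # Predicted class = index of the larger logit in each row.\n",
        "    pred_flat = np.argmax(preds, axis=1).flatten()\n",
        "    labels_flat = labels.flatten()\n",
        "    return np.sum(pred_flat == labels_flat) / len(labels_flat)\n",
        "\n",
        "# Made-up logits for four question pairs (column 0: not duplicate, column 1: duplicate).\n",
        "logits = np.array([[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.2, 0.8]])\n",
        "labels = np.array([0, 1, 0, 0])\n",
        "print(accuracy(logits, labels))  # predictions are [0, 1, 1, 0], so 3 of 4 match: 0.75\n",
        "```"
      ]
    },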
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SvgKNmNTV6-x",
        "colab_type": "text"
      },
      "source": [
        "### Run the neural network model\n",
        "\n",
        "1) Obtain the batch of data to compute gradient from the data loader that we created above.\n",
        "\n",
        "2) Do not store the gradients as they aren't needed in this case.\n",
        "\n",
        "3) Make a forward pass in the network followed by backward pass (backpropagation).\n",
        "\n",
        "4) Update network parameters.\n",
        "\n",
        "5) Track the loss function.\n",
        "\n",
        "6) Run predictions on the validation set and record accuracy.\n",
        "\n",
        "7) Run steps 1 to 6 for the number of epochs specified."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "J9DWJMUNq5Zx",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 104
        },
        "outputId": "3d7ef651-5264-4aba-ae9c-be1de0d1360f"
      },
      "source": [
        "train_loss_set = []\n",
        "\n",
        "# Number of training epochs (authors recommend between 2 and 4)\n",
        "epochs = 2\n",
        "\n",
        "# trange is a tqdm wrapper around the normal python range\n",
        "for _ in trange(epochs, desc=\"Epoch\"):\n",
        "  \n",
        "  \n",
        "  # Training\n",
        "  \n",
        "  # Set our model to training mode (as opposed to evaluation mode)\n",
        "  model.train()\n",
        "  \n",
        "  # Tracking variables\n",
        "  tr_loss = 0\n",
        "  nb_tr_examples, nb_tr_steps = 0, 0\n",
        "  \n",
        "  # Train the data for one epoch\n",
        "  for step, batch in enumerate(train_dataloader):\n",
        "    \n",
        "    # Add batch to GPU\n",
        "    batch = tuple(t.to(device) for t in batch)\n",
        "    \n",
        "    # Unpack the inputs from our dataloader\n",
        "    b_input_ids, b_input_mask, b_labels = batch\n",
        "    \n",
        "    # Clear out the gradients (by default they accumulate)\n",
        "    optimizer.zero_grad()\n",
        "    \n",
        "    # Forward pass\n",
        "    loss = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)\n",
        "    train_loss_set.append(loss.item())    \n",
        "    \n",
        "    # Backward pass\n",
        "    loss.backward()\n",
        "    \n",
        "    # Update parameters and take a step using the computed gradient\n",
        "    optimizer.step()\n",
        "    \n",
        "    \n",
        "    # Update tracking variables\n",
        "    tr_loss += loss.item()\n",
        "    nb_tr_examples += b_input_ids.size(0)\n",
        "    nb_tr_steps += 1\n",
        "\n",
        "  print(\"Train loss: {}\".format(tr_loss/nb_tr_steps))\n",
        "    \n",
        "    \n",
        "  # Validation\n",
        "\n",
        "  # Put model in evaluation mode to evaluate loss on the validation set\n",
        "  model.eval()\n",
        "\n",
        "  # Tracking variables \n",
        "  eval_loss, eval_accuracy = 0, 0\n",
        "  nb_eval_steps, nb_eval_examples = 0, 0\n",
        "\n",
        "  # Evaluate data for one epoch\n",
        "  for batch in validation_dataloader:\n",
        "    \n",
        "    # Add batch to GPU\n",
        "    batch = tuple(t.to(device) for t in batch)\n",
        "    \n",
        "    # Unpack the inputs from our dataloader\n",
        "    b_input_ids, b_input_mask, b_labels = batch\n",
        "    \n",
        "    # Telling the model not to compute or store gradients, saving memory and speeding up validation\n",
        "    with torch.no_grad():\n",
        "      \n",
        "      # Forward pass, calculate logit predictions\n",
        "      logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)\n",
        "    \n",
        "    # Move logits and labels to CPU\n",
        "    logits = logits.detach().cpu().numpy()\n",
        "    label_ids = b_labels.to('cpu').numpy()\n",
        "\n",
        "    tmp_eval_accuracy = accuracy(logits, label_ids)\n",
        "    \n",
        "    eval_accuracy += tmp_eval_accuracy\n",
        "    nb_eval_steps += 1\n",
        "\n",
        "  print(\"Validation Accuracy: {}\".format(eval_accuracy/nb_eval_steps))"
      ],
      "execution_count": 80,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\rEpoch:   0%|          | 0/2 [00:00<?, ?it/s]"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Train loss: 0.5514867920875549\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "\rEpoch:  50%|█████     | 1/2 [03:31<03:31, 211.45s/it]"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Validation Accuracy: 0.796875\n",
            "Train loss: 0.32061706310510635\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "\rEpoch: 100%|██████████| 2/2 [07:02<00:00, 211.42s/it]"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Validation Accuracy: 0.791015625\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BouebF91xtq4",
        "colab_type": "text"
      },
      "source": [
        "### Test on new question pairs to see the performance !!\n",
        "\n",
        "Upload your question pairs test data and apply the same pre-processing techniques and see how the model works!! The questions are really challenging the model. Please refer the questions I uploaded in the file \"test1.csv\" in this repository."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "90MbJVKk5GHW",
        "colab_type": "code",
        "colab": {
          "resources": {
            "http://localhost:8080/nbextensions/google.colab/files.js": {
              "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7Ci8vIE1heCBhbW91bnQgb2YgdGltZSB0byBibG9jayB3YWl0aW5nIGZvciB0aGUgdXNlci4KY29uc3QgRklMRV9DSEFOR0VfVElNRU9VVF9NUyA9IDMwICogMTAwMDsKCmZ1bmN0aW9uIF91cGxvYWRGaWxlcyhpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IHN0ZXBzID0gdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKTsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIC8vIENhY2hlIHN0ZXBzIG9uIHRoZSBvdXRwdXRFbGVtZW50IHRvIG1ha2UgaXQgYXZhaWxhYmxlIGZvciB0aGUgbmV4dCBjYWxsCiAgLy8gdG8gdXBsb2FkRmlsZXNDb250aW51ZSBmcm9tIFB5dGhvbi4KICBvdXRwdXRFbGVtZW50LnN0ZXBzID0gc3RlcHM7CgogIHJldHVybiBfdXBsb2FkRmlsZXNDb250aW51ZShvdXRwdXRJZCk7Cn0KCi8vIFRoaXMgaXMgcm91Z2hseSBhbiBhc
3luYyBnZW5lcmF0b3IgKG5vdCBzdXBwb3J0ZWQgaW4gdGhlIGJyb3dzZXIgeWV0KSwKLy8gd2hlcmUgdGhlcmUgYXJlIG11bHRpcGxlIGFzeW5jaHJvbm91cyBzdGVwcyBhbmQgdGhlIFB5dGhvbiBzaWRlIGlzIGdvaW5nCi8vIHRvIHBvbGwgZm9yIGNvbXBsZXRpb24gb2YgZWFjaCBzdGVwLgovLyBUaGlzIHVzZXMgYSBQcm9taXNlIHRvIGJsb2NrIHRoZSBweXRob24gc2lkZSBvbiBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcCwKLy8gdGhlbiBwYXNzZXMgdGhlIHJlc3VsdCBvZiB0aGUgcHJldmlvdXMgc3RlcCBhcyB0aGUgaW5wdXQgdG8gdGhlIG5leHQgc3RlcC4KZnVuY3Rpb24gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpIHsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIGNvbnN0IHN0ZXBzID0gb3V0cHV0RWxlbWVudC5zdGVwczsKCiAgY29uc3QgbmV4dCA9IHN0ZXBzLm5leHQob3V0cHV0RWxlbWVudC5sYXN0UHJvbWlzZVZhbHVlKTsKICByZXR1cm4gUHJvbWlzZS5yZXNvbHZlKG5leHQudmFsdWUucHJvbWlzZSkudGhlbigodmFsdWUpID0+IHsKICAgIC8vIENhY2hlIHRoZSBsYXN0IHByb21pc2UgdmFsdWUgdG8gbWFrZSBpdCBhdmFpbGFibGUgdG8gdGhlIG5leHQKICAgIC8vIHN0ZXAgb2YgdGhlIGdlbmVyYXRvci4KICAgIG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSA9IHZhbHVlOwogICAgcmV0dXJuIG5leHQudmFsdWUucmVzcG9uc2U7CiAgfSk7Cn0KCi8qKgogKiBHZW5lcmF0b3IgZnVuY3Rpb24gd2hpY2ggaXMgY2FsbGVkIGJldHdlZW4gZWFjaCBhc3luYyBzdGVwIG9mIHRoZSB1cGxvYWQKICogcHJvY2Vzcy4KICogQHBhcmFtIHtzdHJpbmd9IGlucHV0SWQgRWxlbWVudCBJRCBvZiB0aGUgaW5wdXQgZmlsZSBwaWNrZXIgZWxlbWVudC4KICogQHBhcmFtIHtzdHJpbmd9IG91dHB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIG91dHB1dCBkaXNwbGF5LgogKiBAcmV0dXJuIHshSXRlcmFibGU8IU9iamVjdD59IEl0ZXJhYmxlIG9mIG5leHQgc3RlcHMuCiAqLwpmdW5jdGlvbiogdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKSB7CiAgY29uc3QgaW5wdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQoaW5wdXRJZCk7CiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gZmFsc2U7CgogIGNvbnN0IG91dHB1dEVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50QnlJZChvdXRwdXRJZCk7CiAgb3V0cHV0RWxlbWVudC5pbm5lckhUTUwgPSAnJzsKCiAgY29uc3QgcGlja2VkUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBpbnB1dEVsZW1lbnQuYWRkRXZlbnRMaXN0ZW5lcignY2hhbmdlJywgKGUpID0+IHsKICAgICAgcmVzb2x2ZShlLnRhcmdldC5maWxlcyk7CiAgICB9KTsKICB9KTsKCiAgY29uc3QgY2FuY2VsID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnYnV0dG9uJyk7CiAgaW5wdXRFbGVtZW50LnBhc
mVudEVsZW1lbnQuYXBwZW5kQ2hpbGQoY2FuY2VsKTsKICBjYW5jZWwudGV4dENvbnRlbnQgPSAnQ2FuY2VsIHVwbG9hZCc7CiAgY29uc3QgY2FuY2VsUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBjYW5jZWwub25jbGljayA9ICgpID0+IHsKICAgICAgcmVzb2x2ZShudWxsKTsKICAgIH07CiAgfSk7CgogIC8vIENhbmNlbCB1cGxvYWQgaWYgdXNlciBoYXNuJ3QgcGlja2VkIGFueXRoaW5nIGluIHRpbWVvdXQuCiAgY29uc3QgdGltZW91dFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgc2V0VGltZW91dCgoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9LCBGSUxFX0NIQU5HRV9USU1FT1VUX01TKTsKICB9KTsKCiAgLy8gV2FpdCBmb3IgdGhlIHVzZXIgdG8gcGljayB0aGUgZmlsZXMuCiAgY29uc3QgZmlsZXMgPSB5aWVsZCB7CiAgICBwcm9taXNlOiBQcm9taXNlLnJhY2UoW3BpY2tlZFByb21pc2UsIHRpbWVvdXRQcm9taXNlLCBjYW5jZWxQcm9taXNlXSksCiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdzdGFydGluZycsCiAgICB9CiAgfTsKCiAgaWYgKCFmaWxlcykgewogICAgcmV0dXJuIHsKICAgICAgcmVzcG9uc2U6IHsKICAgICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICAgIH0KICAgIH07CiAgfQoKICBjYW5jZWwucmVtb3ZlKCk7CgogIC8vIERpc2FibGUgdGhlIGlucHV0IGVsZW1lbnQgc2luY2UgZnVydGhlciBwaWNrcyBhcmUgbm90IGFsbG93ZWQuCiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gdHJ1ZTsKCiAgZm9yIChjb25zdCBmaWxlIG9mIGZpbGVzKSB7CiAgICBjb25zdCBsaSA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2xpJyk7CiAgICBsaS5hcHBlbmQoc3BhbihmaWxlLm5hbWUsIHtmb250V2VpZ2h0OiAnYm9sZCd9KSk7CiAgICBsaS5hcHBlbmQoc3BhbigKICAgICAgICBgKCR7ZmlsZS50eXBlIHx8ICduL2EnfSkgLSAke2ZpbGUuc2l6ZX0gYnl0ZXMsIGAgKwogICAgICAgIGBsYXN0IG1vZGlmaWVkOiAkewogICAgICAgICAgICBmaWxlLmxhc3RNb2RpZmllZERhdGUgPyBmaWxlLmxhc3RNb2RpZmllZERhdGUudG9Mb2NhbGVEYXRlU3RyaW5nKCkgOgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnbi9hJ30gLSBgKSk7CiAgICBjb25zdCBwZXJjZW50ID0gc3BhbignMCUgZG9uZScpOwogICAgbGkuYXBwZW5kQ2hpbGQocGVyY2VudCk7CgogICAgb3V0cHV0RWxlbWVudC5hcHBlbmRDaGlsZChsaSk7CgogICAgY29uc3QgZmlsZURhdGFQcm9taXNlID0gbmV3IFByb21pc2UoKHJlc29sdmUpID0+IHsKICAgICAgY29uc3QgcmVhZGVyID0gbmV3IEZpbGVSZWFkZXIoKTsKICAgICAgcmVhZGVyLm9ubG9hZCA9IChlKSA9PiB7CiAgICAgICAgcmVzb2x2ZShlLnRhcmdldC5yZXN1bHQpOwogICAgICB9OwogICAgICByZWFkZXIucmVhZEFzQXJyYXlCdWZmZXIoZmlsZSk7CiAgICB9KTsKICAgIC8vIFdhaXQgZm9yIHRoZSBkY
XRhIHRvIGJlIHJlYWR5LgogICAgbGV0IGZpbGVEYXRhID0geWllbGQgewogICAgICBwcm9taXNlOiBmaWxlRGF0YVByb21pc2UsCiAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgYWN0aW9uOiAnY29udGludWUnLAogICAgICB9CiAgICB9OwoKICAgIC8vIFVzZSBhIGNodW5rZWQgc2VuZGluZyB0byBhdm9pZCBtZXNzYWdlIHNpemUgbGltaXRzLiBTZWUgYi82MjExNTY2MC4KICAgIGxldCBwb3NpdGlvbiA9IDA7CiAgICB3aGlsZSAocG9zaXRpb24gPCBmaWxlRGF0YS5ieXRlTGVuZ3RoKSB7CiAgICAgIGNvbnN0IGxlbmd0aCA9IE1hdGgubWluKGZpbGVEYXRhLmJ5dGVMZW5ndGggLSBwb3NpdGlvbiwgTUFYX1BBWUxPQURfU0laRSk7CiAgICAgIGNvbnN0IGNodW5rID0gbmV3IFVpbnQ4QXJyYXkoZmlsZURhdGEsIHBvc2l0aW9uLCBsZW5ndGgpOwogICAgICBwb3NpdGlvbiArPSBsZW5ndGg7CgogICAgICBjb25zdCBiYXNlNjQgPSBidG9hKFN0cmluZy5mcm9tQ2hhckNvZGUuYXBwbHkobnVsbCwgY2h1bmspKTsKICAgICAgeWllbGQgewogICAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgICBhY3Rpb246ICdhcHBlbmQnLAogICAgICAgICAgZmlsZTogZmlsZS5uYW1lLAogICAgICAgICAgZGF0YTogYmFzZTY0LAogICAgICAgIH0sCiAgICAgIH07CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPQogICAgICAgICAgYCR7TWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCl9JSBkb25lYDsKICAgIH0KICB9CgogIC8vIEFsbCBkb25lLgogIHlpZWxkIHsKICAgIHJlc3BvbnNlOiB7CiAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgIH0KICB9Owp9CgpzY29wZS5nb29nbGUgPSBzY29wZS5nb29nbGUgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYiA9IHNjb3BlLmdvb2dsZS5jb2xhYiB8fCB7fTsKc2NvcGUuZ29vZ2xlLmNvbGFiLl9maWxlcyA9IHsKICBfdXBsb2FkRmlsZXMsCiAgX3VwbG9hZEZpbGVzQ29udGludWUsCn07Cn0pKHNlbGYpOwo=",
              "ok": true,
              "headers": [
                [
                  "content-type",
                  "application/javascript"
                ]
              ],
              "status": 200,
              "status_text": ""
            }
          },
          "base_uri": "https://localhost:8080/",
          "height": 75
        },
        "outputId": "7b9f5c4b-b8f4-49fb-d75a-0c587bb08997"
      },
      "source": [
        "uploaded = files.upload()"
      ],
      "execution_count": 91,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "     <input type=\"file\" id=\"files-cb9f4b24-66e3-455f-8727-ed196b8d8048\" name=\"files[]\" multiple disabled />\n",
              "     <output id=\"result-cb9f4b24-66e3-455f-8727-ed196b8d8048\">\n",
              "      Upload widget is only available when the cell has been executed in the\n",
              "      current browser session. Please rerun this cell to enable.\n",
              "      </output>\n",
              "      <script src=\"/nbextensions/google.colab/files.js\"></script> "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Saving test1.csv to test1 (2).csv\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wZn99f1iEVR4",
        "colab_type": "code",
        "colab": {}
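      },
      "source": [
        "# Colab saved the upload as 'test1 (2).csv' because the name test1.csv was already taken (see the output above)\n",
        "test = pd.read_csv('test1 (2).csv')\n",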
        "all_input_ids = prep_data(test['question1'].values, test['question2'].values)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m9mfehmbJoO_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "MAX_LEN = 128\n",
        "# Pad (or truncate) every token-id sequence at the end so that each has exactly MAX_LEN entries\n",
        "pad_input_ids = pad_sequences(all_input_ids,\n",
        "                          maxlen=MAX_LEN, dtype=\"long\", truncating=\"post\", padding=\"post\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZIs5B2QMJu_P",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Create attention masks\n",
        "attention_masks = []\n",
        "\n",
        "# Create a mask of 1s for each token followed by 0s for padding\n",
        "for seq in pad_input_ids:\n",
        "  seq_mask = [float(i>0) for i in seq]\n",
        "  attention_masks.append(seq_mask)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eMc-urKfJyXh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "prediction_inputs = torch.tensor(pad_input_ids)\n",
        "prediction_masks = torch.tensor(attention_masks)\n",
        "\n",
        "batch_size = 32\n",
        "\n",
        "# A sequential sampler keeps the predictions in the same order as the rows of the test file\n",
        "prediction_data = TensorDataset(prediction_inputs, prediction_masks)\n",
        "prediction_sampler = SequentialSampler(prediction_data)\n",
        "prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-DBx9vv3J6PS",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "e983c7f4-1bf6-4102-e014-ee04cf817e34"
      },
      "source": [
        "# Prediction on test set\n",
        "\n",
        "# Put model in evaluation mode\n",
        "model.eval()\n",
        "\n",
        "# Collect the logits of every batch\n",
        "predictions = []\n",
        "\n",
        "# Predict\n",
        "for batch in prediction_dataloader:\n",
        "  # Move the batch to the device selected earlier\n",
        "  batch = tuple(t.to(device) for t in batch)\n",
        "  # Unpack the inputs from our dataloader\n",
        "  b_input_ids, b_input_mask = batch\n",
        "  # Disable gradient tracking to save memory and speed up prediction\n",
        "  with torch.no_grad():\n",
        "    # Forward pass, calculate logit predictions\n",
        "    logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)\n",
        "\n",
        "  # Move logits to CPU and convert them to a NumPy array\n",
        "  logits = logits.detach().cpu().numpy()\n",
        "\n",
        "  # Store the logits and print the predicted class for each question pair\n",
        "  predictions.append(logits)\n",
        "  pred_flat = np.argmax(logits, axis=1).flatten()\n",
        "  print(pred_flat)"
      ],
      "execution_count": 96,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[0 0 1 1 1 1 1 0 0 0 1]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZMYhQrKrLNOD",
        "colab_type": "text"
      },
      "source": [
        "The model works quite well. The results are uploaded in the results.csv file. It is interesting to see that it gets a few complex classifications right. Note that the model was fine-tuned on a sample of only 5000 points; fine-tuning on the whole dataset should improve the results further."
      ]
    }
  ]
}